{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f7d979c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../..')\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
    "import tensorflow as tf\n",
    "tf.get_logger().setLevel('ERROR')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dafa4c11",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ampligraph\n",
    "# Benchmark datasets are under ampligraph.datasets module\n",
    "from ampligraph.datasets import load_fb15k_237\n",
    "# load fb15k-237 dataset\n",
    "dataset = load_fb15k_237()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "87ed1243",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metal device set to: Apple M1 Pro\n",
      "\n",
      "systemMemory: 32.00 GB\n",
      "maxCacheSize: 10.67 GB\n",
      "\n",
      "Epoch 1/10\n",
      "29/29 [==============================] - 3s 101ms/step - loss: 6625.2280\n",
      "Epoch 2/10\n",
      "3/3 [==============================] - 1s 418ms/steposs: 6505.03\n",
      "29/29 [==============================] - 3s 92ms/step - loss: 6505.0376 - val_mrr: 0.0576 - val_mr: 3106.6989 - val_hits@1: 0.0000e+00 - val_hits@10: 0.1562 - val_hits@100: 0.3352\n",
      "WARNING - Found untraced functions such as _get_ranks while saving (showing 1 of 1). These functions will not be directly callable after loading.\n",
      "Epoch 3/10\n",
      "29/29 [==============================] - 1s 49ms/step - loss: 6386.7886\n",
      "Epoch 4/10\n",
      "3/3 [==============================] - 0s 101ms/steposs: 6267.646\n",
      "29/29 [==============================] - 2s 63ms/step - loss: 6267.6465 - val_mrr: 0.0673 - val_mr: 2225.0398 - val_hits@1: 0.0000e+00 - val_hits@10: 0.1903 - val_hits@100: 0.3835\n",
      "Epoch 5/10\n",
      "29/29 [==============================] - 1s 47ms/step - loss: 6149.4019\n",
      "Epoch 6/10\n",
      "3/3 [==============================] - 0s 118ms/steposs: 6029.96\n",
      "29/29 [==============================] - 2s 63ms/step - loss: 6029.9697 - val_mrr: 0.0765 - val_mr: 1584.9233 - val_hits@1: 0.0000e+00 - val_hits@10: 0.1989 - val_hits@100: 0.4119\n",
      "Epoch 7/10\n",
      "29/29 [==============================] - 1s 51ms/step - loss: 5910.1436\n",
      "Epoch 8/10\n",
      "3/3 [==============================] - 0s 126ms/steposs: 5814.78\n",
      "29/29 [==============================] - 2s 59ms/step - loss: 5789.8135 - val_mrr: 0.0732 - val_mr: 1316.3835 - val_hits@1: 0.0000e+00 - val_hits@10: 0.2045 - val_hits@100: 0.4489\n",
      "Epoch 9/10\n",
      "29/29 [==============================] - 1s 43ms/step - loss: 5670.5464\n",
      "Epoch 10/10\n",
      "3/3 [==============================] - 1s 186ms/steposs: 5569.24\n",
      "29/29 [==============================] - 2s 70ms/step - loss: 5552.7808 - val_mrr: 0.0728 - val_mr: 1146.0284 - val_hits@1: 0.0000e+00 - val_hits@10: 0.2131 - val_hits@100: 0.4801\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x2931fd2a0>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Import the KGE model\n",
    "from ampligraph.latent_features import ScoringBasedEmbeddingModel\n",
    "\n",
    "# create the model with transe scoring function\n",
    "model = ScoringBasedEmbeddingModel(eta=1, \n",
    "                                     k=10,\n",
    "                                     scoring_type='TransE')\n",
    "\n",
    "\n",
    "# compile the model with loss and optimizer\n",
    "model.compile(optimizer='adam', loss='multiclass_nll')\n",
    "\n",
    "# Use this for checkpoints at regular intervals\n",
    "checkpoint = tf.keras.callbacks.ModelCheckpoint('./chkpt1_transe', monitor='val_mrr', verbose=0, \n",
    "                                                save_best_only=True, mode='min')\n",
    "\n",
    "dataset = load_fb15k_237()\n",
    "\n",
    "model.fit(dataset['train'],\n",
    "          batch_size=10000,\n",
    "          epochs=10,\n",
    "          validation_freq=2,\n",
    "          validation_batch_size=100,\n",
    "          validation_burn_in=2,\n",
    "          validation_data = dataset['valid'][::100],\n",
    "          callbacks=[checkpoint])     # Pass the callback to the fit function\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e5b4dc3b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "206/206 [==============================] - 407s 2s/step\n",
      "MR: 921.3923084450533\n",
      "MRR: 0.17730360323410402\n",
      "hits@1: 0.12151384675604267\n",
      "hits@10: 0.2851551032390645\n"
     ]
    }
   ],
   "source": [
    "# evaluate on the test set\n",
    "ranks = model.evaluate(dataset['test'], # test set\n",
    "                       batch_size=100, # evaluation batch size\n",
    "                       corrupt_side='s,o', \n",
    "                       use_filter={'train':dataset['train'], # Filter to be used for evaluation\n",
    "                                   'valid':dataset['valid'],\n",
    "                                   'test':dataset['test']}\n",
    "                       )\n",
    "\n",
    "# import the evaluation metrics\n",
    "from ampligraph.evaluation.metrics import mrr_score, hits_at_n_score, mr_score\n",
    "\n",
    "print('MR:', mr_score(ranks))\n",
    "print('MRR:', mrr_score(ranks))\n",
    "print('hits@1:', hits_at_n_score(ranks, 1))\n",
    "print('hits@10:', hits_at_n_score(ranks, 10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fee12a7d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING - Found untraced functions such as _get_ranks while saving (showing 1 of 1). These functions will not be directly callable after loading.\n"
     ]
    }
   ],
   "source": [
    "from ampligraph.utils import save_model\n",
    "# explictly save the model\n",
    "save_model(model, 'saved_model_transE')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "adc72ebd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved model does not include a db file. Skipping.\n"
     ]
    }
   ],
   "source": [
    "from ampligraph.utils import restore_model\n",
    "\n",
    "# restore saved models or checkpoints\n",
    "model = restore_model('saved_model_transE')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "115f311e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "206/206 [==============================] - 254s 1s/step\n",
      "MR: 921.3923084450533\n",
      "MRR: 0.17730360323410402\n",
      "hits@1: 0.12151384675604267\n",
      "hits@10: 0.2851551032390645\n"
     ]
    }
   ],
   "source": [
    "# evaluate on the test set\n",
    "ranks = model.evaluate(dataset['test'], # test set\n",
    "                       batch_size=100, # evaluation batch size\n",
    "                       corrupt_side='s,o', \n",
    "                       use_filter={'train':dataset['train'], # Filter to be used for evaluation\n",
    "                                   'valid':dataset['valid'],\n",
    "                                   'test':dataset['test']}\n",
    "                       )\n",
    "\n",
    "# import the evaluation metrics\n",
    "from ampligraph.evaluation.metrics import mrr_score, hits_at_n_score, mr_score\n",
    "\n",
    "print('MR:', mr_score(ranks))\n",
    "print('MRR:', mrr_score(ranks))\n",
    "print('hits@1:', hits_at_n_score(ranks, 1))\n",
    "print('hits@10:', hits_at_n_score(ranks, 10))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "2e69f3670cdad0193847aaa0b77be56c05c951fcbdd384ff882dde0464f4de76"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
